import time
import logging
import os
import random
import torch
import torch.utils.data
from diff_utils.helpers import *
import pytorch3d.transforms as pytorch3d
import copy

import pandas as pd 
import numpy as np
import csv, json

from tqdm import tqdm

class ModulationLoader(torch.utils.data.Dataset):
    """Dataset of per-instance latent vectors, optionally paired with point-cloud archives.

    In *unconditional* mode (``pc_path is None``) every item is
    ``{'point_cloud': False, 'latent': <tensor>}``.

    In *conditional* mode each latent is paired with an ``.npz`` archive that is
    read lazily in ``__getitem__``; the item then also carries the observed
    points, ground-truth pose / joint parameters and their centroid-shifted
    ("zero mean") counterparts.
    """

    def __init__(self, data_path, pc_path=None, split_file=None, pc_size=None):
        """Load every latent listed in ``split_file``.

        Args:
            data_path: Root directory of the latent files. The class name is taken
                from the second ``'/'``-separated component of this string.
            pc_path: Directory holding one ``<instance>.npz`` per instance. ``None``
                selects unconditional mode.
            split_file: Iterable of instance names; a ``.json`` substring in a name
                is removed. Required despite the ``None`` default.
            pc_size: Accepted for interface compatibility; not used by this class.

        Raises:
            TypeError: If ``split_file`` is ``None``.
        """
        super().__init__()

        if split_file is None:
            # Iterating None would raise an opaque TypeError deep in the loader;
            # fail early with the same exception type and a clear message.
            raise TypeError("split_file must be an iterable of instance names, not None")

        self.conditional = pc_path is not None

        if self.conditional:
            self.modulations, self.pc_paths = self.load_modulations(data_path, pc_path, split_file)
            if self.modulations:
                # Only index element 0 when something was actually loaded.
                print("data shape, dataset len: ", self.modulations[0].shape, len(self.modulations))
            assert len(self.pc_paths) == len(self.modulations), \
                "number of point-cloud paths and latents differ"
        else:
            self.modulations = self.unconditional_load_modulations(data_path, split_file)

        if not self.modulations:
            logging.getLogger(__name__).warning(
                "ModulationLoader found no latent files under '%s'", data_path)

    def __len__(self):
        """Return the number of loaded latents."""
        return len(self.modulations)

    def __getitem__(self, index):
        """Return the sample dict for ``index``.

        Unconditional mode returns only ``point_cloud`` (``False``) and ``latent``.
        Conditional mode additionally reads the instance's ``.npz`` archive and
        returns the raw and centroid-shifted ground-truth entries.
        """
        if not self.conditional:
            return {'point_cloud': False, 'latent': self.modulations[index]}

        # np.load on an .npz returns an NpzFile that keeps the file open; the
        # context manager releases the handle once the arrays are materialised.
        with np.load(self.pc_paths[index]) as archive:
            observed_pts = torch.from_numpy(archive['camera_partial_pc'])
            base_pose = archive['base_pose']  # read once, sliced twice below
            joint_xyz = torch.from_numpy(archive['joint_xyz'][1:])  # first row dropped
            joint_rpy = torch.from_numpy(archive['joint_rpy'][1:])  # first row dropped
            seg_labels = torch.from_numpy(archive['cls'])
            canonical_pts = torch.from_numpy(archive['canonical_partial_pc']).float()

        rotation = torch.from_numpy(base_pose[:3, :3])
        translation = torch.from_numpy(base_pose[:3, 3])
        # 6D rotation representation of the transposed rotation block.
        rot_6d = pytorch3d.matrix_to_rotation_6d(rotation.permute(1, 0))

        data_dict = {
            'gt_pose': torch.cat([rot_6d.float(), translation.float()], dim=-1),
            'pts': observed_pts,
            'gt_xyz': joint_xyz,
            'gt_rpy': joint_rpy,
            'seg': seg_labels,
        }

        # Zero-centre: shift everything by the centroid of the first three point
        # columns. The subtraction is done in place on clones so that each result
        # keeps the dtype of its source tensor.
        centroid = torch.mean(observed_pts[:, :3], dim=0)

        centred_pts = observed_pts.clone()
        centred_pts[:, :3] -= centroid  # broadcasts over the point dimension
        data_dict['zero_mean_pts'] = centred_pts

        centred_pose = data_dict['gt_pose'].clone()
        centred_pose[-3:] -= centroid  # only the translation part is shifted
        data_dict['zero_mean_gt_pose'] = centred_pose

        centred_xyz = joint_xyz.clone()
        centred_xyz -= centroid

        # NOTE(review): the centroid (a position) is also subtracted from the rpy
        # values, which look like angles. Kept as-is to preserve the data format
        # downstream code was trained on; confirm whether this is intended.
        centred_rpy = joint_rpy.clone()
        centred_rpy -= centroid

        data_dict['zero_mean_gt_xyz'] = centred_xyz.reshape(-1)
        data_dict['zero_mean_gt_rpy'] = centred_rpy.reshape(-1)
        data_dict['pts_center'] = centroid

        data_dict['point_cloud'] = canonical_pts
        data_dict['latent'] = self.modulations[index]
        return data_dict

    @staticmethod
    def _read_latent(path):
        """Read a text file with ``np.loadtxt`` and return it as a float32 tensor."""
        return torch.from_numpy(np.loadtxt(path)).float()

    def load_modulations(self, data_source, pc_source, split, f_name="latent.txt", add_flip_augment=False, return_filepaths=True):
        """Load latents together with the paths of their point-cloud files.

        Args:
            data_source: Root directory of the latents. ``data_source.split('/')[1]``
                is used as the class name, so the string needs at least one ``'/'``
                (for an absolute path this picks the first directory name).
            pc_source: Directory containing ``<instance>.npz`` files.
            split: Iterable of instance names (``.json`` is stripped).
            f_name: Latent file name inside each instance directory.
            add_flip_augment: Load ``latent_0.txt`` .. ``latent_3.txt`` instead of
                ``f_name``.
            return_filepaths: When false, return only the list of latents.

        Returns:
            ``(latents, filepaths)`` or just ``latents``. Entries whose latent file
            does not exist are skipped.
        """
        files = []
        filepaths = []
        skipped = 0
        for dataset in split:
            dataset = dataset.replace(".json", "")
            class_name = data_source.split('/')[1]

            if add_flip_augment:
                for idx in range(4):
                    instance_filename = os.path.join(data_source, "latent_{}.txt".format(idx))
                    if not os.path.isfile(instance_filename):
                        print("Requested non-existent file '{}'".format(instance_filename))
                        continue
                    files.append(self._read_latent(instance_filename))
                # NOTE(review): one path is appended per split entry while up to four
                # latents are appended, so the two lists can differ in length here.
                # Behaviour kept unchanged; verify before enabling this option.
                filepaths.append(os.path.join(pc_source, dataset, "sdf_data.csv"))
                continue

            instance_filename = os.path.join(data_source, class_name, dataset, f_name)
            if not os.path.isfile(instance_filename):
                skipped += 1
                continue
            files.append(self._read_latent(instance_filename))
            filepaths.append(os.path.join(pc_source, f"{dataset}.npz"))

        if skipped:
            # One summary line instead of silently dropping entries.
            logging.getLogger(__name__).warning(
                "Skipped %d split entries without a '%s' file under '%s'",
                skipped, f_name, data_source)

        if return_filepaths:
            return files, filepaths
        return files

    def unconditional_load_modulations(self, data_source, split, f_name="latent.txt", add_flip_augment=False):
        """Load latents only (no point-cloud paths).

        Args:
            data_source: Root directory of the latents. ``data_source.split('/')[1]``
                is used as the class name, so the string needs at least one ``'/'``.
            split: Iterable of instance names (``.json`` is stripped).
            f_name: Latent file name inside each instance directory.
            add_flip_augment: Load ``latent_0.txt`` .. ``latent_3.txt`` from the class
                directory instead of ``f_name``.

        Returns:
            List of float32 tensors. Entries whose latent file does not exist are
            skipped.
        """
        files = []
        skipped = 0
        for dataset in split:
            dataset = dataset.replace(".json", "")
            class_name = data_source.split('/')[1]

            if add_flip_augment:
                for idx in range(4):
                    instance_filename = os.path.join(data_source, class_name, "latent_{}.txt".format(idx))
                    if not os.path.isfile(instance_filename):
                        print("Requested non-existent file '{}'".format(instance_filename))
                        continue
                    files.append(self._read_latent(instance_filename))
                continue

            instance_filename = os.path.join(data_source, class_name, dataset, f_name)
            if not os.path.isfile(instance_filename):
                skipped += 1
                continue
            files.append(self._read_latent(instance_filename))

        if skipped:
            # One summary line instead of silently dropping entries.
            logging.getLogger(__name__).warning(
                "Skipped %d split entries without a '%s' file under '%s'",
                skipped, f_name, data_source)
        return files